!pip install /kaggle/input/dicomsdl/dicomsdl-0.109.2-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
libdir = '.'
valid_bs_ = 4
num_workers_ = 2
from glob import glob
class CFG:
seed=42
device='GPU'
num_workers=num_workers_
valid_bs=valid_bs_
mode = 'test'
data_dir = '/kaggle/input/rsna-2023-abdominal-trauma-detection/'
img_dir = data_dir + f'{mode}_images'
target_cols=['bowel_healthy', 'bowel_injury', 'extravasation_healthy',
'extravasation_injury', 'kidney_healthy', 'kidney_low', 'kidney_high',
'liver_healthy', 'liver_low', 'liver_high', 'spleen_healthy',
'spleen_low', 'spleen_high']
archs_list = ["resnest50d"]*5
weights_list = glob('/kaggle/input/rsna-gru-aug3c/*.pth')
archs_list2 = ["resnest50d"]*5
weights_list2 = glob('/kaggle/input/rsna-kidney-mxaug3c/*.pth')
seq_len = 96
img_size = 256
dropout=0.1
import sys
package_paths = [f'{libdir}/pytorch-image-models-master']
for pth in package_paths:
sys.path.append(pth)
import ast
from glob import glob
import cv2
# from skimage import io
import os
from datetime import datetime
import time
import random
from tqdm import tqdm
from contextlib import contextmanager
import math
import dicomsdl
import numpy as np
import pandas as pd
import sklearn
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics
from sklearn.model_selection import GroupKFold, StratifiedKFold, StratifiedGroupKFold
import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
import timm
import warnings
import joblib
from scipy.ndimage import zoom  # scipy.ndimage.interpolation is deprecated
import nibabel as nib
import pydicom as dicom
import gc
from torch.nn import DataParallel
from albumentations import Resize
import albumentations
from albumentations.pytorch import ToTensorV2
from torch.cuda.amp import autocast, GradScaler
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def seed_everything(seed=42):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True  # kept True for speed; set False for strictly deterministic runs
seed_everything(CFG.seed)
def do_pad_to_square(image):
l, h, w = image.shape
if w > h:
pad = w - h
pad0 = pad // 2
pad1 = pad - pad0
image = F.pad(image, [0, 0, pad0, pad1], mode='constant', value=0)
if w < h:
pad = h - w
pad0 = pad // 2
pad1 = pad - pad0
image = F.pad(image, [pad0, pad1, 0, 0], mode='constant', value=0)
return image
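# Hedged sketch: do_pad_to_square zero-pads the shorter in-plane axis so that
# height == width. Dummy volumes only, (depth, height, width) layout.
import torch
assert do_pad_to_square(torch.zeros(4, 100, 120)).shape == (4, 120, 120)  # w > h: pad h
assert do_pad_to_square(torch.zeros(4, 120, 100)).shape == (4, 120, 120)  # h > w: pad w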
def do_scale_to_size(image, spacing, max_size):
dz, dy, dx = spacing
    l, s, s = image.shape  # h == w after do_pad_to_square; scale to max size
if max_size != s:
scale = max_size / s
        l = int(dz / dy * l * 0.5)  # we assume spacing dz,dy,dx = 2,1,1
l = int(scale * l)
h = int(scale * s)
w = int(scale * s)
image = F.interpolate(
image.unsqueeze(0).unsqueeze(0),
size=(l, h, w),
mode='trilinear',
align_corners=False,
).squeeze(0).squeeze(0)
return image
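# Hedged sketch: with spacing dz,dy,dx = 2,1,1 the depth is first rescaled against
# the in-plane spacing (times 0.5), then scaled by max_size/s. Dummy volume only.
import torch
_out = do_scale_to_size(torch.zeros(200, 512, 512), (2.0, 1.0, 1.0), max_size=256)
assert _out.shape == (100, 256, 256)  # depth: int(2/1*200*0.5)=200, then *256/512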
def dicomsdl_to_numpy_image(ds, index=0):
info = ds.getPixelDataInfo()
if info['SamplesPerPixel'] != 1:
raise RuntimeError('SamplesPerPixel != 1') # number of separate planes in this image
shape = [info['Rows'], info['Cols']]
dtype = info['dtype']
outarr = np.empty(shape, dtype=dtype)
ds.copyFrameData(index, outarr)
return outarr
def load_dicomsdl_dir(dcm_dir, slice_range=None):
dcm_file = sorted(glob(f'{dcm_dir}/*.dcm'), key=lambda x: int(x.split('/')[-1].split('.')[0]))
# https://www.kaggle.com/competitions/rsna-2023-abdominal-trauma-detection/discussion/435815
    # fake extra slices so a single-slice series doesn't break downstream processing
if len(dcm_file)==1:
dcm = dicomsdl.open(dcm_file[0])
pixel_array = dicomsdl_to_numpy_image(dcm)
pixel_array = pixel_array.astype(np.float32)
image = np.stack([pixel_array]*16)
dz,dy,dx = 1,1,1
return image, (dz,dy,dx)
dcm_file_excep = []
image = []
for f in dcm_file:
try:
dcm = dicomsdl.open(f)
pixel_array = dicomsdl_to_numpy_image(dcm)
if dcm.PixelRepresentation == 1:
bit_shift = dcm.BitsAllocated - dcm.BitsStored
dtype = pixel_array.dtype
pixel_array = (pixel_array << bit_shift).astype(dtype) >> bit_shift
#processing
pixel_array = pixel_array.astype(np.float32)
pixel_array = dcm.RescaleSlope * pixel_array + dcm.RescaleIntercept
xmin = dcm.WindowCenter-0.5-(dcm.WindowWidth-1)* 0.5
xmax = dcm.WindowCenter-0.5+(dcm.WindowWidth-1)* 0.5
norm = np.empty_like(pixel_array, dtype=np.uint8)
dicomsdl.util.convert_to_uint8(pixel_array, norm, xmin, xmax)
if dcm.PhotometricInterpretation == 'MONOCHROME1':
norm = 255 - norm
img = norm
if(img.shape[0]>512):
s1 = img[:512].sum()
s2 = img[(img.shape[0]//2)-256:(img.shape[0]//2)+256].sum()
s3 = img[-512:].sum()
if(s1>s2 and s1>s3):
img = img[:512]
elif(s3>s1 and s3>s2):
img = img[-512:]
else:
img = img[(img.shape[0]//2)-256:(img.shape[0]//2)+256]
            if(img.shape[1]>512):
                s1 = img[:, :512].sum()
                s2 = img[:, (img.shape[1]//2)-256:(img.shape[1]//2)+256].sum()
                s3 = img[:, -512:].sum()
                if(s1>s2 and s1>s3):
                    img = img[:, :512]
                elif(s3>s1 and s3>s2):
                    img = img[:, -512:]
                else:
                    img = img[:, (img.shape[1]//2)-256:(img.shape[1]//2)+256]
image.append(img)
        except Exception:
dcm_file_excep.append(f)
dcm_file = [i for i in dcm_file if i not in dcm_file_excep]
#------------------------------------
if slice_range is None:
slice_min = int(dcm_file[0].split('/')[-1].split('.')[0])
slice_max = int(dcm_file[-1].split('/')[-1].split('.')[0])+1
slice_range=(slice_min, slice_max)
slice_min, slice_max = slice_range
    # check z-axis inversion from the first and last slice positions
    dcm0 = dicomsdl.open(f'{dcm_dir}/{slice_min}.dcm')
    dcmN = dicomsdl.open(f'{dcm_dir}/{slice_max-1}.dcm')
    sx0, sy0, sz0 = dcm0.ImagePositionPatient
    sxN, syN, szN = dcmN.ImagePositionPatient
    if szN > sz0:
        image = image[::-1]
    dx, dy = dcm0.PixelSpacing
    dz = np.abs((szN - sz0) / (slice_max - slice_min - 1))
image = np.stack(image)
return image, (dz,dy,dx)
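# Hedged sketch: typical call on one test series directory (layout from this
# competition's test_images; requires the Kaggle input to be attached).
_series_dirs = sorted(glob(f'{CFG.img_dir}/*/*'))
if _series_dirs:
    _vol, (_dz, _dy, _dx) = load_dicomsdl_dir(_series_dirs[0])
    print(_vol.shape, (_dz, _dy, _dx))  # (num_slices, H, W), spacing in mm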
def pre_process_slice_predictor(image):
l,s,s = image.shape
L,S,S = CFG.seq_len, CFG.img_size, CFG.img_size
l1 = int(S / s * l)
image = F.interpolate(
image.unsqueeze(0).unsqueeze(0),
size=[l1,S,S],
mode='trilinear'
).squeeze(0).squeeze(0)
# pad or crop to max length L
if L > l1:
image = F.pad(image, [0, 0, 0, 0, 0, L - l1], mode='constant', value=0)
if L < l1:
image = image[:L]
return image
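# Hedged sketch: pre_process_slice_predictor always emits (CFG.seq_len, 256, 256),
# zero-padding short volumes and truncating long ones. Dummy tensors only.
import torch
assert pre_process_slice_predictor(torch.zeros(40, 256, 256)).shape == (96, 256, 256)
assert pre_process_slice_predictor(torch.zeros(400, 256, 256)).shape == (96, 256, 256)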
def get_df():
df_test = pd.read_csv(os.path.join(CFG.data_dir, f'{CFG.mode}_series_meta.csv')).merge(
pd.read_csv(os.path.join(CFG.data_dir, f"{CFG.mode.replace('test', 'sample_submission')}.csv")))
df_test['path'] = CFG.img_dir + '/' + df_test['patient_id'].astype(str)+'/'+df_test['series_id'].astype(str)
df_test = df_test[df_test['path'].apply(os.path.isdir)]
df_test = df_test.reset_index(drop=True)
return df_test
class TestDataset(Dataset):
def __init__(self, csv):
self.csv = csv
def __len__(self):
return self.csv.shape[0]
def __getitem__(self, index):
row = self.csv.iloc[index]
image, (dz, dy, dx) = load_dicomsdl_dir(row.path, slice_range=None)
image = torch.from_numpy(image).float()
image = do_pad_to_square(image)
image = do_scale_to_size(image, (dz, dy, dx), max_size=CFG.img_size)
image = pre_process_slice_predictor(image)
image = image.to(torch.float16)
image /= 255
return row.patient_id, image
import torch.nn as nn
from itertools import repeat
class SpatialDropout(nn.Module):
def __init__(self, drop=0.5):
super(SpatialDropout, self).__init__()
self.drop = drop
def forward(self, inputs, noise_shape=None):
outputs = inputs.clone()
if noise_shape is None:
noise_shape = (inputs.shape[0], *repeat(1, inputs.dim()-2), inputs.shape[-1])
self.noise_shape = noise_shape
if not self.training or self.drop == 0:
return inputs
else:
noises = self._make_noises(inputs)
if self.drop == 1:
noises.fill_(0.0)
else:
noises.bernoulli_(1 - self.drop).div_(1 - self.drop)
noises = noises.expand_as(inputs)
outputs.mul_(noises)
return outputs
def _make_noises(self, inputs):
        return inputs.new_empty(self.noise_shape)
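# Hedged sketch: SpatialDropout is the identity in eval mode; in train mode it
# zeroes entire feature channels (noise shape (batch, 1, ..., features)). Dummy input.
import torch
_sd = SpatialDropout(drop=0.5).eval()
_x = torch.randn(2, 8, 16)
assert torch.equal(_sd(_x), _x)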
import torch
from torch import nn
import torch.nn.functional as F
from typing import Dict, Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
class MLPAttentionNetwork(nn.Module):
def __init__(self, hidden_dim, attention_dim=None):
super(MLPAttentionNetwork, self).__init__()
self.hidden_dim = hidden_dim
self.attention_dim = attention_dim
if self.attention_dim is None:
self.attention_dim = self.hidden_dim
# W * x + b
self.proj_w = nn.Linear(self.hidden_dim, self.attention_dim, bias=True)
# v.T
self.proj_v = nn.Linear(self.attention_dim, 1, bias=False)
def forward(self, x):
batch_size, seq_len, _ = x.size()
H = torch.tanh(self.proj_w(x))
        att_scores = torch.softmax(self.proj_v(H), dim=1)
attn_x = (x * att_scores).sum(1)
return attn_x
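# Hedged sketch: the MLP attention pools the sequence axis away,
# (batch, seq, hidden) -> (batch, hidden), with weights softmaxed over seq. Dummy sizes.
import torch
_attn = MLPAttentionNetwork(hidden_dim=256)
assert _attn(torch.randn(2, 32, 256)).shape == (2, 256)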
class RSNAClassifier(nn.Module):
def __init__(self, model_arch, hidden_dim=256, seq_len=CFG.seq_len, pretrained=False):
super().__init__()
self.seq_len = seq_len
self.model_arch = model_arch
self.model = timm.create_model(model_arch, in_chans=3, pretrained=pretrained)
cnn_feature = self.model.fc.in_features
self.model.global_pool = nn.Identity()
self.model.fc = nn.Identity()
self.pooling = nn.AdaptiveAvgPool2d(1)
self.spatialdropout = SpatialDropout(CFG.dropout)
self.gru = nn.GRU(cnn_feature, hidden_dim, num_layers=2, batch_first=True, bidirectional=True)
self.mlp_attention_layer = MLPAttentionNetwork(2 * hidden_dim)
self.logits = nn.Sequential(
# nn.Linear(hidden_dim*2, 128),
# nn.ReLU(),
# nn.Dropout(CFG.dropout),
            nn.Linear(256, len(CFG.target_cols)),  # 256 = 2*hidden_dim (hidden_dim=128 at inference)
)
def forward(self, x):
bs = x.size(0)
x = x.reshape(bs*self.seq_len//3, 3, x.size(2), x.size(3))
features = self.model(x)
features = self.pooling(features).view(bs*self.seq_len//3, -1)
features = self.spatialdropout(features)
# print(features.shape)
features = features.reshape(bs, self.seq_len//3, -1)
features, _ = self.gru(features)
atten_out = self.mlp_attention_layer(features)
pred = self.logits(atten_out)
pred = pred.view(bs, -1)
return pred
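# Hedged sketch: the 2.5D trick -- every 3 consecutive slices become one 3-channel
# CNN input, giving a GRU sequence of length seq_len//3. Randomly initialized
# resnest50d, dummy input, CPU only.
import torch
_m = RSNAClassifier('resnest50d', hidden_dim=128, seq_len=CFG.seq_len, pretrained=False).eval()
with torch.no_grad():
    assert _m(torch.randn(1, CFG.seq_len, 64, 64)).shape == (1, len(CFG.target_cols))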
class RSNAClassifier2(nn.Module):
def __init__(self, model_arch, hidden_dim=256, seq_len=CFG.seq_len, pretrained=False):
super().__init__()
self.seq_len = seq_len
self.model_arch = model_arch
self.model = timm.create_model(model_arch, in_chans=3, pretrained=pretrained)
cnn_feature = self.model.fc.in_features
self.model.global_pool = nn.Identity()
self.model.fc = nn.Identity()
self.pooling = nn.AdaptiveAvgPool2d(1)
self.spatialdropout = SpatialDropout(CFG.dropout)
self.gru = nn.GRU(cnn_feature, hidden_dim, num_layers=2, batch_first=True, bidirectional=True)
self.mlp_attention_layer = MLPAttentionNetwork(2 * hidden_dim)
self.logits = nn.Sequential(
# nn.Linear(hidden_dim*2, 128),
# nn.ReLU(),
# nn.Dropout(CFG.dropout),
            nn.Linear(256, 9),  # 9 = kidney/liver/spleen 3-class heads; 256 = 2*hidden_dim
)
def forward(self, x):
bs = x.size(0)
x = x.reshape(bs*self.seq_len//3, 3, x.size(2), x.size(3))
features = self.model(x)
features = self.pooling(features).view(bs*self.seq_len//3, -1)
features = self.spatialdropout(features)
# print(features.shape)
features = features.reshape(bs, self.seq_len//3, -1)
features, _ = self.gru(features)
atten_out = self.mlp_attention_layer(features)
pred = self.logits(atten_out)
pred = pred.view(bs, -1)
return pred
def get_preds():
test_df = get_df()
test_dataset = TestDataset(test_df)
test_loader = DataLoader(test_dataset, batch_size=CFG.valid_bs, shuffle=False, num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
cls_model_list = []
for m_arch, m_weight in zip(CFG.archs_list, CFG.weights_list):
model = RSNAClassifier(m_arch, hidden_dim=128, seq_len=CFG.seq_len, pretrained=False)
model.to(device)
model.load_state_dict(torch.load(m_weight)["model"])
model.eval()
cls_model_list.append(model)
cls_model_list2 = []
for m_arch, m_weight in zip(CFG.archs_list2, CFG.weights_list2):
model = RSNAClassifier2(m_arch, hidden_dim=128, seq_len=CFG.seq_len, pretrained=False)
model.to(device)
model.load_state_dict(torch.load(m_weight)["model"])
model.eval()
cls_model_list2.append(model)
print(len(cls_model_list), len(cls_model_list2))
all_preds = []
patient_ids = []
for step, (patient, images) in enumerate(test_loader):
images = images.to(device, dtype=torch.float)
models_preds = []
for model in cls_model_list:
with torch.no_grad():
y_preds = model(images)
y_preds = y_preds.squeeze(1)
models_preds.append(y_preds.sigmoid().to('cpu').numpy())
models_preds = np.mean(models_preds, axis=0)
models_preds2 = []
for model in cls_model_list2:
with torch.no_grad():
y_preds = model(images)
y_preds = y_preds.squeeze(1)
models_preds2.append(y_preds.sigmoid().to('cpu').numpy())
models_preds2 = np.mean(models_preds2, axis=0)
models_preds[:, 4:] = 0.5*models_preds[:, 4:] + 0.5*models_preds2
all_preds.append(models_preds)
patient_ids += patient.tolist()
all_preds = np.concatenate(all_preds)
    # rescale injury columns toward the competition's per-class loss weights
    all_preds[:, 1] *= 2   # bowel_injury
    all_preds[:, 3] *= 6   # extravasation_injury
    all_preds[:, [5, 8, 11]] *= 2   # kidney/liver/spleen low
    all_preds[:, [6, 9, 12]] *= 4   # kidney/liver/spleen high
    del cls_model_list, cls_model_list2, model, test_dataset, test_loader, test_df
gc.collect()
torch.cuda.empty_cache()
sub = pd.DataFrame(all_preds, columns = CFG.target_cols)
sub['patient_id'] = patient_ids
sub = sub[['patient_id'] + CFG.target_cols].groupby('patient_id').mean().reset_index()
return sub
sub__ = get_preds()
!pip install pydicom
!pip install nibabel
!pip install timm
!pip install transformers
!pip install albumentations
#!pip install segmentation_models_pytorch
import sys
sys.path.append('/kaggle/input/segmentation-library/segmentation_models.pytorch')
sys.path.append('/kaggle/input/pretrainedmodels/pretrainedmodels-0.7.4')
sys.path.append('/kaggle/input/efficientnet-library/EfficientNet-PyTorch')
import pandas as pd
import numpy as np
import cv2
import zipfile
import os
import gc
import glob
import shutil
import matplotlib.pyplot as plt
from tqdm import tqdm
import pydicom #as dicom
import nibabel as nib
import albumentations as A
import torch
import torch.nn as nn
from torch.nn import functional as F
import timm
from transformers import RobertaPreLayerNormConfig, RobertaPreLayerNormModel
from segmentation_models_pytorch.decoders.unet.model import (
UnetDecoder,
SegmentationHead,
)
device = 'cuda'
class SegModel(nn.Module):
def __init__(self):
super(SegModel, self).__init__()
self.n_classes = len(
[
'background',
'liver',
'spleen',
'left kidney',
'right kidney',
'bowel'
])
in_chans = 1
self.encoder = timm.create_model(
'regnety_002',
pretrained=False,
features_only=True,
in_chans=in_chans,
)
encoder_channels = tuple(
[in_chans]
+ [
self.encoder.feature_info[i]["num_chs"]
for i in range(len(self.encoder.feature_info))
]
)
self.decoder = UnetDecoder(
encoder_channels=encoder_channels,
decoder_channels=(256, 128, 64, 32, 16),
n_blocks=5,
use_batchnorm=True,
center=False,
attention_type=None,
)
self.segmentation_head = SegmentationHead(
in_channels=16,
out_channels=self.n_classes,
activation=None,
kernel_size=3,
)
self.bce_seg = nn.BCEWithLogitsLoss()
def forward(self, x_in):
enc_out = self.encoder(x_in)
decoder_out = self.decoder(*[x_in] + enc_out)
x_seg = self.segmentation_head(decoder_out)
return nn.Sigmoid()(x_seg)
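# Hedged sketch: the segmentation model maps a (N, 1, H, W) batch to per-pixel
# sigmoid probabilities over the 6 classes (random weights, dummy input, CPU).
import torch
_seg = SegModel().eval()
with torch.no_grad():
    assert _seg(torch.zeros(1, 1, 224, 224)).shape == (1, 6, 224, 224)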
class SegPipeline(nn.Module):
def __init__(self):
super(SegPipeline, self).__init__()
self.seg_model = SegModel().cuda()
self.seg_model.eval()
checkpoint = torch.load('/kaggle/input/rsna-atd-segmentation-weights/epoch040-trainloss0.0049-valloss0.0068.bin')
self.seg_model.load_state_dict(checkpoint)
self.seg_model.eval()
self.batch_size = 128
@torch.no_grad()
def forward(self, x_in):
x_in = x_in.cuda()
segmentations = []
for i in range(0, x_in.shape[0], self.batch_size):
segmentations.append(self.seg_model(x_in[i:i+self.batch_size]))
return torch.cat(segmentations).cpu()
seg_pipeline = SegPipeline()
from transformers import RobertaPreLayerNormConfig, RobertaPreLayerNormModel
class FeatureExtractor(nn.Module):
def __init__(self, hidden, num_channel):
super(FeatureExtractor, self).__init__()
self.hidden = hidden
self.num_channel = num_channel
self.cnn = timm.create_model(model_name = 'regnety_002',
pretrained = False,
num_classes = 0,
in_chans = num_channel)
self.fc = nn.Linear(hidden, hidden//2)
def forward(self, x):
batch_size, num_frame, h, w = x.shape
x = x.reshape(batch_size, num_frame//self.num_channel, self.num_channel, h, w)
x = x.reshape(-1, self.num_channel, h, w)
x = self.cnn(x)
x = x.reshape(batch_size, num_frame//self.num_channel, self.hidden)
x = self.fc(x)
return x
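# Hedged sketch: frames are grouped num_channel-deep, so a (batch, frames, H, W)
# clip becomes (batch, frames//num_channel, hidden//2); regnety_002's pooled
# feature width is 368, matching `hidden`. Dummy input, random weights.
import torch
_fe = FeatureExtractor(hidden=368, num_channel=2).eval()
with torch.no_grad():
    assert _fe(torch.randn(1, 8, 64, 64)).shape == (1, 4, 184)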
class ContextProcessor(nn.Module):
def __init__(self, hidden):
super(ContextProcessor, self).__init__()
self.transformer = RobertaPreLayerNormModel(
RobertaPreLayerNormConfig(
hidden_size = hidden//2,
num_hidden_layers = 1,
num_attention_heads = 4,
intermediate_size = hidden*2,
hidden_act = 'gelu_new',
)
)
del self.transformer.embeddings.word_embeddings
self.dense = nn.Linear(hidden, hidden)
self.activation = nn.ReLU()
def forward(self, x):
x = self.transformer(inputs_embeds = x).last_hidden_state
apool = torch.mean(x, dim = 1)
mpool, _ = torch.max(x, dim = 1)
x = torch.cat([mpool, apool], dim = -1)
x = self.dense(x)
x = self.activation(x)
return x
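# Hedged sketch: ContextProcessor runs a 1-layer pre-LayerNorm RoBERTa over the
# frame embeddings and pools mean+max into a (batch, hidden) vector. Dummy input.
import torch
_cp = ContextProcessor(hidden=368).eval()
with torch.no_grad():
    assert _cp(torch.randn(2, 10, 184)).shape == (2, 368)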
class CustomModel(nn.Module):
def __init__(self, hidden = 368, num_channel = 2):
super(CustomModel, self).__init__()
self.full_extractor = FeatureExtractor(hidden=hidden, num_channel=num_channel)
self.kidney_extractor = FeatureExtractor(hidden=hidden, num_channel=num_channel)
self.liver_extractor = FeatureExtractor(hidden=hidden, num_channel=num_channel)
self.spleen_extractor = FeatureExtractor(hidden=hidden, num_channel=num_channel)
self.full_processor = ContextProcessor(hidden=hidden)
self.kidney_processor = ContextProcessor(hidden=hidden)
self.liver_processor = ContextProcessor(hidden=hidden)
self.spleen_processor = ContextProcessor(hidden=hidden)
self.bowel = nn.Linear(hidden, 2)
self.extravasation = nn.Linear(hidden, 2)
self.kidney = nn.Linear(hidden, 3)
self.liver = nn.Linear(hidden, 3)
self.spleen = nn.Linear(hidden, 3)
self.softmax = nn.Softmax(dim = -1)
def forward(self, full_input, crop_liver, crop_spleen, crop_kidney):
full_output = self.full_extractor(full_input)
kidney_output = self.kidney_extractor(crop_kidney)
liver_output = self.liver_extractor(crop_liver)
spleen_output = self.spleen_extractor(crop_spleen)
full_output2 = self.full_processor(torch.cat([full_output, kidney_output, liver_output, spleen_output], dim = 1))
kidney_output2 = self.kidney_processor(torch.cat([full_output, kidney_output], dim = 1))
liver_output2 = self.liver_processor(torch.cat([full_output, liver_output], dim = 1))
spleen_output2 = self.spleen_processor(torch.cat([full_output, spleen_output], dim = 1))
        # inference is run with batch size 1, so index [0] takes the only sample
        bowel = self.softmax(self.bowel(full_output2))[0].tolist()
extravasation = self.softmax(self.extravasation(full_output2))[0].tolist()
kidney = self.softmax(self.kidney(kidney_output2))[0].tolist()
liver = self.softmax(self.liver(liver_output2))[0].tolist()
spleen = self.softmax(self.spleen(spleen_output2))[0].tolist()
#any_injury = torch.stack([
# self.softmax(bowel)[:, 0],
# self.softmax(extravasation)[:, 0],
# self.softmax(kidney)[:, 0],
# self.softmax(liver)[:, 0],
# self.softmax(spleen)[:, 0]
#], dim = -1)
#any_injury = 1 - any_injury
#if mode == 'train':
# mask = mask[:, [0, 2, 4, 7, 10]]
# assert any_injury.shape == mask.shape
# any_injury = any_injury * mask
# any_injury = any_injury.sum(1) / mask.sum(1)
# any_injury
#else:
# any_injury, _ = any_injury.max(1)
return bowel, extravasation, kidney, liver, spleen#, any_injury
model1 = CustomModel()
weights1 = torch.load('/kaggle/input/rsna-atd-weights/25dcnn-channel2-512-1thfold-trainloss0.165-valloss0.539.bin')
model1 = model1.to(device)
model1.load_state_dict(weights1)
model2 = CustomModel()
weights2 = torch.load('/kaggle/input/rsna-atd-weights/25dcnn-channel2-512-2thfold-trainloss0.2092-valloss0.4989.bin')
model2 = model2.to(device)
model2.load_state_dict(weights2)
model3 = CustomModel()
weights3 = torch.load('/kaggle/input/rsna-atd-weights/25dcnn-channel2-512-3thfold-trainloss0.186-valloss0.4705.bin')
model3 = model3.to(device)
model3.load_state_dict(weights3)
model4 = CustomModel()
weights4 = torch.load('/kaggle/input/rsna-atd-weights/25dcnn-channel2-512-4thfold-trainloss0.1983-valloss0.4716.bin')
model4 = model4.to(device)
model4.load_state_dict(weights4)
model5 = CustomModel()
weights5 = torch.load('/kaggle/input/rsna-atd-weights/25dcnn-channel2-512-5thfold-trainloss0.1559-valloss0.5194.bin')
model5 = model5.to(device)
model5.load_state_dict(weights5)
# weight per model
class CustomModelEnsemble(nn.Module):
def __init__(self, models):
super(CustomModelEnsemble, self).__init__()
self.models = models
'''
self.weight = np.array(
[
[0.5, 1, 1, 2, 1, 3, 2, 5],
[0.2, 2, 2, 2, 1, 2, 4, 3],
[0.2, 1, 3, 1, 2, 3, 2, 6],
[0.2, 2, 2, 1, 1, 4, 4, 4],
[0.8, 1, 2, 2, 1, 1, 5, 4]
]
)
'''
self.weight = np.array(
[
[0.9, 4, 2, 4, 2, 6, 6, 6],
[0.9, 1, 4, 3, 2, 5, 5, 6],
[0.2, 3, 2, 1, 2, 4, 2, 6],
[0.5, 2, 2, 2, 2, 2, 6, 6],
[1, 2, 3, 2, 6, 3, 6, 5]
]
)
for model in self.models:
model.eval()
def forward(self, x, crop_liver, crop_spleen, crop_kidney):
bowels = []
extravasations = []
kidneys = []
livers = []
spleens = []
for model in self.models:
bowel, extravasation, kidney, liver, spleen = model(x, crop_liver, crop_spleen, crop_kidney)
bowels.append(bowel)
extravasations.append(extravasation)
kidneys.append(kidney)
livers.append(liver)
spleens.append(spleen)
bowels = np.array(bowels)
extravasations = np.array(extravasations)
kidneys = np.array(kidneys)
livers = np.array(livers)
spleens = np.array(spleens)
        # per-fold calibration weights; columns are ordered
        # [bowel, kidney_low, liver_low, spleen_low, kidney_high, liver_high, spleen_high, extravasation]
        bowels[:, 1] *= self.weight[:, 0]
extravasations[:, 1] *= self.weight[:, 7]
kidneys[:, 1] *= self.weight[:, 1]
kidneys[:, 2] *= self.weight[:, 4]
livers[:, 1] *= self.weight[:, 2]
livers[:, 2] *= self.weight[:, 5]
spleens[:, 1] *= self.weight[:, 3]
spleens[:, 2] *= self.weight[:, 6]
        # mild power transform (<1) to flatten fold disagreement before averaging
        bowels = bowels ** 0.8
extravasations = extravasations ** 0.8
kidneys = kidneys ** 0.8
livers = livers ** 0.8
spleens = spleens ** 0.8
bowel = np.mean(bowels, axis=0)
extravasation = np.mean(extravasations, axis=0)
kidney = np.mean(kidneys, axis=0)
liver = np.mean(livers, axis=0)
spleen = np.mean(spleens, axis=0)
return bowel.tolist(), extravasation.tolist(), kidney.tolist(), liver.tolist(), spleen.tolist() #, any_injury
models = [model1, model2, model3, model4, model5]
model = CustomModelEnsemble(models)
# utils
def standardize_pixel_array(dcm: pydicom.dataset.FileDataset) -> np.ndarray:
"""
Source : https://www.kaggle.com/competitions/rsna-2023-abdominal-trauma-detection/discussion/427217
"""
# Correct DICOM pixel_array if PixelRepresentation == 1.
pixel_array = dcm.pixel_array
if dcm.PixelRepresentation == 1:
bit_shift = dcm.BitsAllocated - dcm.BitsStored
dtype = pixel_array.dtype
pixel_array = (pixel_array << bit_shift).astype(dtype) >> bit_shift
# pixel_array = pydicom.pixel_data_handlers.util.apply_modality_lut(new_array, dcm)
intercept = float(dcm.RescaleIntercept)
slope = float(dcm.RescaleSlope)
center = int(dcm.WindowCenter)
width = int(dcm.WindowWidth)
low = center - width / 2
high = center + width / 2
pixel_array = (pixel_array * slope) + intercept
pixel_array = np.clip(pixel_array, low, high)
return pixel_array
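# Hedged sketch: the same rescale + window math on made-up numbers, with a
# stand-in object instead of a real pydicom dataset (all values hypothetical).
import numpy as np
class _FakeDcm:
    pixel_array = np.array([[0, 100, 2000]], dtype=np.int16)
    PixelRepresentation = 0
    RescaleIntercept = -1024.0
    RescaleSlope = 1.0
    WindowCenter = 40   # soft-tissue-like window
    WindowWidth = 400
_arr = standardize_pixel_array(_FakeDcm())
assert _arr.min() == -160.0 and _arr.max() == 240.0  # clipped to [c - w/2, c + w/2]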
def get_high_aortic_hu(df):
    # keep, per patient, the series with the highest aortic HU
    high_aortic_hu_df = []
    for patient_id in df.patient_id.unique():
        sample = df.query(f'patient_id=={patient_id}').sort_values('aortic_hu', ascending=False).reset_index(drop=True)
        high_aortic_hu_df.append(sample.loc[0])
    high_aortic_hu_df = pd.concat(high_aortic_hu_df, axis=1).transpose().reset_index(drop=True)
    high_aortic_hu_df = high_aortic_hu_df.astype('int32')
    return high_aortic_hu_df
# data process
image_height, image_width = 320, 320
transform_function = A.Resize(image_height, image_width)
def get_stride_box(min_y, min_x, max_y, max_x, stride=10):
min_y = np.clip(min_y - stride, a_min=0, a_max=512)
min_x = np.clip(min_x - stride, a_min=0, a_max=512)
max_y = np.clip(max_y + stride, a_min=0, a_max=512)
max_x = np.clip(max_x + stride, a_min=0, a_max=512)
return min_y, min_x, max_y, max_x
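# Hedged sketch: the box is grown by `stride` pixels per side, then clipped to
# the 512x512 frame (made-up coordinates).
assert get_stride_box(5, 100, 300, 510) == (0, 90, 310, 512)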
def get_segmentation_inputs(video, image_height=320, image_width=320, transform_function = transform_function):
video = video.transpose(1, 2, 0)
video = cv2.resize(video, dsize=(int(image_height), int(image_width)))
video = video.astype(np.uint8).transpose(2, 0, 1)
transforms = [transform_function(image=video[i]) for i in range(video.shape[0])]
video = np.stack([x['image'] for x in transforms], axis=0)
video = torch.tensor(video, dtype=torch.float)
video = video / 255
video = video[:, None]
return video
def get_coordinates(logit, thres=0.5):
liver = (logit[:, 0]>thres).float()
spleen = (logit[:, 1]>thres).float()
left_kidney = (logit[:, 2]>thres).float()
right_kidney = (logit[:, 3]>thres).float()
organs = [liver, spleen, left_kidney+right_kidney]
coordinates = []
    for organ in organs:
        # torch.argwhere gives (K, 3) voxel coordinates; reduce over dim 0 for
        # per-axis extents (np.argwhere returns an ndarray, which has no .values)
        ones_coordinates = torch.argwhere(organ != 0)
        min_z, min_y, min_x = ones_coordinates.min(dim=0).values
        max_z, max_y, max_x = ones_coordinates.max(dim=0).values
        coordinate = [min_z, min_y, min_x, max_z, max_y, max_x]
        coordinate = torch.tensor(coordinate)
        coordinates.append(coordinate)
coordinates = torch.stack(coordinates).numpy()
return coordinates
def logit2box(seg_output):
    try:
        coordinates = get_coordinates(seg_output, thres=0.5)
    except Exception:
        try:
            # no voxel cleared thres=0.5: retry with a looser threshold
            coordinates = get_coordinates(seg_output, thres=0.1)
        except Exception:
            # still nothing: degenerate box; downstream falls back to the full volume
            coordinates = np.zeros([3, 6])
    return coordinates
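# Hedged sketch: logit2box on synthetic logits -- an all-zero map yields the
# degenerate 3x6 fallback box, while a synthetic blob recovers its extents.
# Dummy tensors only; each row is [min_z, min_y, min_x, max_z, max_y, max_x].
import numpy as np
import torch
assert logit2box(torch.zeros(4, 6, 8, 8)).shape == (3, 6)
_logit = torch.zeros(4, 6, 8, 8)
_logit[1:3, :4, 2:5, 3:6] = 1.0  # blob over z=1..2, y=2..4, x=3..5
assert (logit2box(_logit)[0] == np.array([1, 2, 3, 2, 4, 5])).all()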
def get_cropped_organs(video, box, ratio=(512/320)):
organs = []
for i in range(box.shape[0]):
min_z, min_y, min_x, max_z, max_y, max_x = box[i]
if 0.0 not in [max_z - min_z, max_y - min_y, max_x - min_x]:
min_y, min_x, max_y, max_x = int(ratio*min_y), int(ratio*min_x), int(ratio*max_y), int(ratio*max_x)
min_y, min_x, max_y, max_x = get_stride_box(min_y, min_x, max_y, max_x)
organ = video[min_z:max_z, min_y:max_y, min_x:max_x]
else:
organ = video
organ = F.interpolate(
organ.unsqueeze(0).unsqueeze(0),
size=[96, 224, 224],
mode='trilinear'
).squeeze(0).squeeze(0)
organs.append(organ)
return organs
def get_crop_inputs(video, box):
video = torch.tensor(video, dtype=torch.float)
organs = get_cropped_organs(video, box)
crop_liver, crop_spleen, crop_kidney = organs
video = F.interpolate(
video.unsqueeze(0).unsqueeze(0),
size=[128, 224, 224],
mode='trilinear'
).squeeze(0).squeeze(0)
    # round-trip through uint8 to quantize intensities before scaling to [0, 1]
    video = torch.tensor(video.numpy().astype(np.uint8), dtype=torch.float)
crop_liver = torch.tensor(crop_liver.numpy().astype(np.uint8), dtype=torch.float)
crop_spleen = torch.tensor(crop_spleen.numpy().astype(np.uint8), dtype=torch.float)
crop_kidney = torch.tensor(crop_kidney.numpy().astype(np.uint8), dtype=torch.float)
video, crop_liver, crop_spleen, crop_kidney = video/255.0, crop_liver/255.0, crop_spleen/255.0, crop_kidney/255.0
return video, crop_liver, crop_spleen, crop_kidney
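# Hedged sketch: with a degenerate all-zero box every organ crop falls back to
# the full volume; outputs are fixed-size float tensors in [0, 1]. Dummy volume.
import numpy as np
_vid, _lv, _sp, _kd = get_crop_inputs(np.zeros([8, 64, 64], dtype=np.float32),
                                      np.zeros([3, 6], dtype=np.int64))
assert _vid.shape == (128, 224, 224)
assert _lv.shape == _sp.shape == _kd.shape == (96, 224, 224)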
import torch.nn.functional as F
img_h, img_w = 512, 512
down_sampling = 1   # alternative setting: 4
max_frame = 1024    # alternative setting: 256
model.eval()
test_series_meta = pd.read_csv('/kaggle/input/rsna-2023-abdominal-trauma-detection/test_series_meta.csv')
#test_series_meta = get_high_aortic_hu(test_series_meta)
bowel_healthy = []
bowel_injury = []
extravasation_healthy = []
extravasation_injury = []
kidney_healthy = []
kidney_low = []
kidney_high = []
liver_healthy = []
liver_low = []
liver_high = []
spleen_healthy = []
spleen_low = []
spleen_high = []
for i in tqdm(range(len(test_series_meta))):
patient_id, series_id = test_series_meta.loc[i]['patient_id'], test_series_meta.loc[i]['series_id']
path = f'/kaggle/input/rsna-2023-abdominal-trauma-detection/test_images/{int(patient_id)}/{int(series_id)}/'
dcm_paths = glob.glob(path + '*.dcm')
imgs = {}
    for f in sorted(dcm_paths, key=lambda x: int(x.split('/')[-1].split('.')[0])):
        try:
            dicom = pydicom.dcmread(f)
            pos_z = dicom[(0x20, 0x32)].value[-1]  # z of ImagePositionPatient
            img = standardize_pixel_array(dicom)
            img = (img - img.min()) / (img.max() - img.min() + 1e-6)
            if dicom.PhotometricInterpretation == "MONOCHROME1":
                img = 1 - img
            imgs[pos_z] = cv2.resize((img * 255).astype(np.uint8), dsize=(img_h, img_w))
        except Exception:
            # skip unreadable slices: assigning zeros under a stale pos_z would
            # overwrite the previous slice
            continue
video = []
seg_video = []
    for k in sorted(imgs.keys()):
video.append(cv2.resize(imgs[k], dsize=(512, 512)))
seg_video.append(cv2.resize(imgs[k], dsize=(224, 224)))
del imgs
if len(video) > 0:
video = np.stack(video)
seg_video = np.stack(seg_video)
else:
video = np.zeros([1, 512, 512])
seg_video = np.zeros([1, 224, 224])
video = video[::down_sampling][:max_frame]
video = torch.tensor(video, dtype=torch.float)
seg_video = seg_video[::down_sampling][:max_frame]
seg_video = torch.tensor(seg_video, dtype=torch.float)
seg_video = F.interpolate(
seg_video.unsqueeze(0).unsqueeze(0),
size=[256, 224, 224],
mode='trilinear'
).squeeze(0).squeeze(0)
seg_video = seg_video.numpy()
seg_video = seg_video.astype(np.uint8)
seg_inputs = get_segmentation_inputs(seg_video)
seg_logit = seg_pipeline(seg_inputs.to(device)).cpu()
box = logit2box(seg_logit)
video = F.interpolate(
video.unsqueeze(0).unsqueeze(0),
size=[256, 512, 512],
mode='trilinear'
).squeeze(0).squeeze(0)
video = video.numpy()
#video = video.astype(np.uint8)
crop_inputs = get_crop_inputs(video, box)
video, crop_liver, crop_spleen, crop_kidney = [x[None].to(device) for x in crop_inputs]
with torch.no_grad():
pred = model(video, crop_liver, crop_spleen, crop_kidney)
bowel_healthy.append(pred[0][0])
bowel_injury.append(pred[0][1])
extravasation_healthy.append(pred[1][0])
extravasation_injury.append(pred[1][1])
kidney_healthy.append(pred[2][0])
kidney_low.append(pred[2][1])
kidney_high.append(pred[2][2])
liver_healthy.append(pred[3][0])
liver_low.append(pred[3][1])
liver_high.append(pred[3][2])
spleen_healthy.append(pred[4][0])
spleen_low.append(pred[4][1])
spleen_high.append(pred[4][2])
test_series_meta['bowel_healthy'] = bowel_healthy
test_series_meta['bowel_injury'] = bowel_injury
test_series_meta['extravasation_healthy'] = extravasation_healthy
test_series_meta['extravasation_injury'] = extravasation_injury
test_series_meta['kidney_healthy'] = kidney_healthy
test_series_meta['kidney_low'] = kidney_low
test_series_meta['kidney_high'] = kidney_high
test_series_meta['liver_healthy'] = liver_healthy
test_series_meta['liver_low'] = liver_low
test_series_meta['liver_high'] = liver_high
test_series_meta['spleen_healthy'] = spleen_healthy
test_series_meta['spleen_low'] = spleen_low
test_series_meta['spleen_high'] = spleen_high
test_series_meta = test_series_meta.drop(columns = ['series_id', 'aortic_hu'])
test_series_meta = test_series_meta.groupby('patient_id').mean().reset_index()
#test_series_meta['bowel_injury'] *= 0.2
#test_series_meta['kidney_low'] *= 1
#test_series_meta['liver_low'] *= 2
#test_series_meta['spleen_low'] *= 1
#test_series_meta['kidney_high'] *= 1
#test_series_meta['liver_high'] *= 3
#test_series_meta['spleen_high'] *= 3
#test_series_meta['extravasation_injury'] *= 4
test_series_meta
|   | patient_id | bowel_healthy | bowel_injury | extravasation_healthy | extravasation_injury | kidney_healthy | kidney_low | kidney_high | liver_healthy | liver_low | liver_high | spleen_healthy | spleen_low | spleen_high |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 48843 | 0.746932 | 0.251207 | 0.794182 | 1.288614 | 0.894185 | 0.191874 | 0.221785 | 0.880342 | 0.378936 | 0.108251 | 0.921701 | 0.208414 | 0.131717 |
| 1 | 50046 | 0.757264 | 0.240841 | 0.706023 | 1.683516 | 0.802744 | 0.368841 | 0.255433 | 0.888176 | 0.359339 | 0.113470 | 0.756671 | 0.197676 | 0.812473 |
| 2 | 63706 | 0.873660 | 0.109681 | 0.835094 | 1.098972 | 0.873895 | 0.263539 | 0.203754 | 0.836620 | 0.500140 | 0.163468 | 0.644048 | 0.296225 | 1.142041 |
# normalize each organ's probability group to sum to 1 before ensembling
for cols in [CFG.target_cols[:2], CFG.target_cols[2:4], CFG.target_cols[4:7],
CFG.target_cols[7:10], CFG.target_cols[10:13]]:
test_series_meta[cols] = test_series_meta[cols].div(test_series_meta[cols].sum(1), axis=0)
sub__[cols] = sub__[cols].div(sub__[cols].sum(1), axis=0)
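# Hedged sketch: quick self-check that each organ's probability group now sums
# to 1 row-wise in both prediction frames.
for _cols in [CFG.target_cols[:2], CFG.target_cols[2:4], CFG.target_cols[4:7],
              CFG.target_cols[7:10], CFG.target_cols[10:13]]:
    assert np.allclose(test_series_meta[_cols].sum(1), 1.0)
    assert np.allclose(sub__[_cols].sum(1), 1.0)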
# cross-copy patients seen by only one pipeline so both frames cover everyone,
# then blend 0.65 (2.5D/transformer ensemble) / 0.35 (GRU ensemble); a patient
# present in a single frame keeps full weight (0.65 + 0.35) after the sum below
sub__ = pd.concat([sub__, test_series_meta[~test_series_meta['patient_id'].isin(sub__['patient_id'])]])
test_series_meta = pd.concat([test_series_meta, sub__[~sub__['patient_id'].isin(test_series_meta['patient_id'])]])
test_series_meta[CFG.target_cols] *= 0.65
sub__[CFG.target_cols] *= 0.35
# cols1 = ['bowel_healthy', 'bowel_injury', 'extravasation_healthy', 'extravasation_injury']
# cols2 = ['kidney_healthy', 'kidney_low', 'kidney_high',
# 'liver_healthy', 'liver_low', 'liver_high',
# 'spleen_healthy', 'spleen_low', 'spleen_high']
# test_series_meta[cols1] *= 0.5
# sub__[cols2] *= 0.5
# test_series_meta[cols1] *= 0.8
# sub__[cols2] *= 0.2
# weighted sum per patient: rows from both frames collapse into the final blend
test_series_meta = pd.concat([test_series_meta, sub__])
test_series_meta = test_series_meta.groupby('patient_id').sum().reset_index()
sample_sub = pd.read_csv('/kaggle/input/rsna-2023-abdominal-trauma-detection/sample_submission.csv')
sample_sub = sample_sub[['patient_id']].merge(test_series_meta, how='left').fillna(0.33)  # flat fallback for patients without predictions
sample_sub.to_csv('submission.csv', index = False)
sample_sub
|   | patient_id | bowel_healthy | bowel_injury | extravasation_healthy | extravasation_injury | kidney_healthy | kidney_low | kidney_high | liver_healthy | liver_low | liver_high | spleen_healthy | spleen_low | spleen_high |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 48843 | 0.720865 | 0.279135 | 0.403578 | 0.596422 | 0.587471 | 0.199880 | 0.212649 | 0.638393 | 0.250240 | 0.111367 | 0.650003 | 0.195635 | 0.154363 |
| 1 | 50046 | 0.730922 | 0.269077 | 0.346868 | 0.653132 | 0.516994 | 0.267532 | 0.215475 | 0.649363 | 0.238537 | 0.112100 | 0.452143 | 0.160972 | 0.386886 |
| 2 | 63706 | 0.824630 | 0.175370 | 0.441758 | 0.558242 | 0.571903 | 0.228067 | 0.200030 | 0.583005 | 0.289299 | 0.127695 | 0.375879 | 0.181796 | 0.442325 |